// Copyright by LunaSec (owned by Refinery Labs, Inc)
//
// Licensed under the Business Source License v1.1
// (the "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// https://github.com/lunasec-io/lunasec/blob/master/licenses/BSL-LunaTrace.txt
//
// See the License for the specific language governing permissions and
// limitations under the License.
package vulnerability

import (
	"context"
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path"
	"path/filepath"

	"github.com/Khan/genqlient/graphql"
	"github.com/rs/zerolog/log"
	"github.com/schollz/progressbar/v3"
	"github.com/vektah/gqlparser/v2/gqlerror"
	"go.uber.org/fx"

	"github.com/lunasec-io/lunasec/lunatrace/gogen/gql"
)

type (
	// AdvisoryIngester loads vulnerability advisories and persists them,
	// returning the IDs of the vulnerabilities that were written.
	AdvisoryIngester interface {
		// Ingest loads the advisories found at advisoryLocation, tags them
		// with source, and returns the IDs of the upserted vulnerabilities.
		Ingest(ctx context.Context, source string, advisoryLocation string) ([]string, error)

		// IngestVulnerabilitiesFromSource prepares the advisories for source
		// (optionally narrowed to sourceRelativePath) and then ingests them.
		IngestVulnerabilitiesFromSource(advisoryLocation, source, sourceRelativePath string) ([]string, error)
	}

	// FileAdvisoryIngesterParams is the fx parameter object (it embeds fx.In)
	// carrying the dependencies of FileAdvisoryIngester.
	FileAdvisoryIngesterParams struct {
		fx.In

		GQLClient graphql.Client
	}

	// FileAdvisoryIngester implements AdvisoryIngester by reading advisory
	// files from the local filesystem and upserting them via GQLClient.
	FileAdvisoryIngester struct {
		FileAdvisoryIngesterParams
	}
)

// NewFileIngester constructs the filesystem-backed AdvisoryIngester using the
// injected params (notably the GraphQL client used for upserts).
func NewFileIngester(params FileAdvisoryIngesterParams) AdvisoryIngester {
	ingester := &FileAdvisoryIngester{}
	ingester.FileAdvisoryIngesterParams = params
	return ingester
}

// Ingest loads every advisory at advisoryLocation and upserts it, tagged with
// source, through the GraphQL client. advisoryLocation may be a single file or
// a directory; a directory is walked recursively for ".json" files.
//
// Files are processed in chunks. A chunk that fails to load or upsert is
// logged and skipped, so one bad batch does not abort the whole run. The
// returned slice holds the IDs of every vulnerability that was upserted.
//
// An error is returned when the location cannot be stat'd, when its files
// cannot be collected, or when ctx is cancelled before all chunks are done.
// On cancellation the IDs upserted so far are still returned.
func (f FileAdvisoryIngester) Ingest(ctx context.Context, source string, advisoryLocation string) ([]string, error) {
	var insertedVulns []string

	fileInfo, err := os.Stat(advisoryLocation)
	if err != nil {
		log.Error().
			Err(err).
			Str("advisoryLocation", advisoryLocation).
			Msg("unable to get file info for file location")
		return insertedVulns, err
	}

	advisoryFiles, err := collectAdvisoryFiles(fileInfo.IsDir(), advisoryLocation)
	if err != nil {
		log.Error().
			Err(err).
			Str("advisoryLocation", advisoryLocation).
			Msg("unable to collect advisory files")
		return insertedVulns, err
	}

	bar := progressbar.Default(int64(len(advisoryFiles)))

	// Number of advisory files sent per UpsertVulnerabilities mutation.
	const chunkSize = 100
	for start := 0; start < len(advisoryFiles); start += chunkSize {
		// Stop as soon as the caller gives up, keeping what was already upserted.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return insertedVulns, ctxErr
		}

		end := start + chunkSize
		if end > len(advisoryFiles) {
			end = len(advisoryFiles)
		}
		filenamesChunk := advisoryFiles[start:end]

		vulnerabilitiesInsert, loadErr := generateBulkVulnerabilityInsertQuery(source, filenamesChunk)
		if loadErr != nil {
			log.Warn().
				Err(loadErr).
				Str("firstFile", filenamesChunk[0]).
				Int("chunkSize", len(filenamesChunk)).
				Msg("failed to load vulnerabilities")
		} else if upsertedVulns, upsertErr := f.upsertVulnerabilities(ctx, vulnerabilitiesInsert); upsertErr != nil {
			log.Warn().
				Err(upsertErr).
				Str("firstFile", filenamesChunk[0]).
				Int("chunkSize", len(filenamesChunk)).
				Msg("failed to insert vulnerabilities")
		} else {
			insertedVulns = append(insertedVulns, upsertedVulns...)
		}

		// Advance for every attempted chunk, including failed ones, so the bar
		// reaches 100% once all files have been processed.
		if barErr := bar.Add(len(filenamesChunk)); barErr != nil {
			log.Warn().
				Err(barErr).
				Msg("error incrementing progress bar")
		}
	}
	return insertedVulns, nil
}

// collectAdvisoryFiles returns the advisory files to ingest from
// advisoryLocation.
//
// When isDir is false, the location itself is the only file returned,
// whatever its extension. Otherwise the directory is walked recursively in
// lexical order (filepath.WalkDir), and every non-directory entry whose name
// ends in ".json" is returned.
//
// Any error reported during the walk, including an unreadable or missing
// root, is returned wrapped with the directory being walked. The caller is
// expected to log it.
func collectAdvisoryFiles(isDir bool, advisoryLocation string) ([]string, error) {
	if !isDir {
		return []string{advisoryLocation}, nil
	}

	var files []string
	err := filepath.WalkDir(advisoryLocation, func(p string, d fs.DirEntry, walkErr error) error {
		// d is nil when the root itself cannot be stat'd, so the error must be
		// checked before d is used.
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() || filepath.Ext(p) != ".json" {
			return nil
		}
		files = append(files, p)
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("collecting advisory files from %q: %w", advisoryLocation, err)
	}
	return files, nil
}

// upsertVulnerabilities sends vulnerabilitiesInsert in a single
// UpsertVulnerabilities mutation, resolving conflicts with
// vulnerabilityOnConflict. It returns the IDs of the returned rows as strings.
//
// An empty batch returns (nil, nil) without calling the API. On failure the
// error is logged and returned unchanged. If the error is (or wraps) a
// gqlerror.List, the log also includes the Extensions of each GraphQL error.
func (f FileAdvisoryIngester) upsertVulnerabilities(
	ctx context.Context,
	vulnerabilitiesInsert []*gql.Vulnerability_insert_input,
) ([]string, error) {
	if len(vulnerabilitiesInsert) == 0 {
		return nil, nil
	}

	resp, err := gql.UpsertVulnerabilities(ctx, f.GQLClient, vulnerabilitiesInsert, vulnerabilityOnConflict)
	if err != nil {
		// TODO (cthompson) If there is an error that happens during an upsert, it will be difficult to determine
		// which of the vulnerabilities caused the error. Some better error handling here would help.
		logEvent := log.Error().Err(err)

		// errors.As also matches a gqlerror.List that the client has wrapped.
		var gqlErrorList gqlerror.List
		if errors.As(err, &gqlErrorList) {
			errorMsgs := make([]string, 0, len(gqlErrorList))
			for _, gqlError := range gqlErrorList {
				errorMsgs = append(errorMsgs, fmt.Sprintf("%v", gqlError.Extensions))
			}
			logEvent = logEvent.Strs("context", errorMsgs)
		}
		logEvent.Msg("unable to insert vulnerability")
		return nil, err
	}

	insertedIds := resp.GetInsert_vulnerability().GetReturning()

	idStrings := make([]string, 0, len(insertedIds))
	for _, insertId := range insertedIds {
		idStrings = append(idStrings, insertId.GetId().String())
	}
	return idStrings, nil
}

// IngestVulnerabilitiesFromSource resolves the advisory location for source
// with ensureAdvisoriesExistFromSource, starting from advisoryLocation. If
// sourceRelativePath is non-empty, it is joined onto the resolved location.
// The result is then ingested with Ingest.
//
// The cleanup function returned by ensureAdvisoriesExistFromSource runs when
// this method returns, after ingestion has finished. On a setup error an empty,
// non-nil slice is returned together with that error.
func (f FileAdvisoryIngester) IngestVulnerabilitiesFromSource(advisoryLocation, source, sourceRelativePath string) ([]string, error) {
	advisoryLocation, cleanup, err := ensureAdvisoriesExistFromSource(source, advisoryLocation)

	// Guard against a nil cleanup (e.g. on an early error path) so the
	// deferred call cannot panic and mask the real error.
	if cleanup != nil {
		defer cleanup()
	}
	if err != nil {
		return []string{}, err
	}

	if sourceRelativePath != "" {
		advisoryLocation = path.Join(advisoryLocation, sourceRelativePath)
	}

	// NOTE(review): the interface method takes no context, so callers cannot
	// cancel this ingestion; a background context is used.
	ingestCtx := context.Background()
	return f.Ingest(ingestCtx, source, advisoryLocation)
}
